Graph Data Embedding
- NetworkX for creating the graph.
- Create the sample graph: nodes are classified as either "Birds" or "Timber", and edges represent "origin" relationships.
- create_sample_graph: creates the sample graph with two categories, "Birds" (e.g. eagle, parrot, sparrow) and "Timber" (e.g. oak, cedar, maple), with edges representing the "origin" relationship.
- Transformers from Hugging Face for generating embeddings.
- Generate embeddings: use a transformer-based model (such as BERT or a domain-specific model) to create an embedding for each node in the graph.
- get_bert_embeddings: generates embeddings with a pre-trained BERT model; each node (bird or timber) is tokenized and embedded as the mean of the last hidden states (a masked-mean variant is sketched just after this list).
- create_embeddings_for_graph: generates embeddings for all nodes in the graph.
- Faiss for storing and querying the embeddings.
- Store embeddings in Faiss: the embeddings are indexed with Faiss for efficient similarity search.
- store_embeddings_in_faiss: stores the embeddings in a FAISS index to enable efficient similarity search.
- retrieve_top_k_similar: retrieves the top-k most similar nodes to a given query from the FAISS index.
- User Query: the system retrieves the top 3 most similar graph nodes for the user's query.
- Matplotlib and NetworkX for displaying graph structures.
- Display Results: the original graph and the predicted graph data are displayed in tables, and the network is visualized.
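As a minimal sketch, separate from the helpers defined below: get_bert_embeddings averages over every token position, padding included, which is fine for the short node texts used here. A masked mean that only averages real tokens is a common refinement; the function name masked_mean_embeddings is just an illustrative placeholder, not part of the notebook.

from transformers import BertTokenizer, BertModel
import torch

def masked_mean_embeddings(texts):
    # Hypothetical variant, not used by this notebook: average only the
    # non-padding token states, using the attention mask as weights.
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    mask = inputs["attention_mask"].unsqueeze(-1).float()   # (batch, seq, 1)
    summed = (outputs.last_hidden_state * mask).sum(dim=1)  # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1e-9)                # (batch, 1)
    return (summed / counts).numpy()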
In [ ]:
%pip install -q networkx faiss-cpu transformers pandas matplotlib torch
In [ ]:
import networkx as nx
import numpy as np
import pandas as pd
import faiss
from transformers import BertTokenizer, BertModel
import torch
import matplotlib.pyplot as plt
import textwrap
# 1. Create the sample graph data with two categories (Birds and Timber)
def create_sample_graph():
    G = nx.Graph()

    # Add bird nodes with their origin attribute
    birds = ["eagle", "parrot", "sparrow", "emu"]
    bird_origins = {"eagle": "USA", "parrot": "Australia", "sparrow": "Europe", "emu": "Australia"}
    for bird in birds:
        G.add_node(bird, category="bird", origin=bird_origins[bird])

    # Add timber nodes with their origin attribute
    timbers = ["oak", "cedar", "maple", "tasmanian-oak"]
    timber_origins = {"oak": "USA", "cedar": "Canada", "maple": "Canada", "tasmanian-oak": "Australia"}
    for timber in timbers:
        G.add_node(timber, category="timber", origin=timber_origins[timber])

    # Add edges representing origin relationships
    G.add_edge("eagle", "parrot", relation="origin")
    G.add_edge("parrot", "sparrow", relation="origin")
    G.add_edge("emu", "parrot", relation="origin")
    G.add_edge("oak", "cedar", relation="origin")
    G.add_edge("cedar", "maple", relation="origin")
    G.add_edge("tasmanian-oak", "oak", relation="origin")

    return G, bird_origins, timber_origins
# 2. Generate BERT-based embeddings for a list of texts
def get_bert_embeddings(texts):
    # Loading the tokenizer/model on every call keeps the example simple;
    # cache them if you call this function repeatedly.
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Use the mean of the last-layer hidden states as the embedding
    embeddings = outputs.last_hidden_state.mean(dim=1)
    return embeddings.numpy()
# 3. Create embeddings for graph nodes
def create_embeddings_for_graph(G):
    node_names = list(G.nodes)
    node_texts = [f"{node} ({G.nodes[node]['category']} from {G.nodes[node]['origin']})" for node in node_names]

    # Generate embeddings for each node using BERT
    embeddings = get_bert_embeddings(node_texts)

    # Store embeddings in a DataFrame for easy inspection
    embedding_df = pd.DataFrame(embeddings, index=node_names, columns=[f"dim_{i}" for i in range(embeddings.shape[1])])
    return embedding_df, embeddings
# 4. Store embeddings in a FAISS index for similarity search
def store_embeddings_in_faiss(embeddings):
    dim = embeddings.shape[1]                 # Dimensionality of embeddings
    index = faiss.IndexFlatL2(dim)            # Use the L2 distance metric
    index.add(embeddings.astype(np.float32))  # Add embeddings to the FAISS index
    return index

# 5. Retrieve the top-k most similar nodes for a query embedding
def retrieve_top_k_similar(index, query_embedding, k=2):
    D, I = index.search(query_embedding.astype(np.float32), k)
    return I, D
# Function to wrap text inside nodes for better visualization
def wrap_text(label, width=10):
    return "\n".join(textwrap.wrap(label, width))
# 6. Visualize the graph using NetworkX (with edge labels showing the origin information)
def visualize_graph_side_by_side(G, birds, timbers, title="Graph"):
    fig, axes = plt.subplots(1, 2, figsize=(14, 7))  # 1 row, 2 columns for side-by-side plots

    # Generate subgraphs for birds and timbers
    G_birds = G.subgraph(birds)
    G_timbers = G.subgraph(timbers)

    # First subplot for Birds
    ax1 = axes[0]
    pos_birds = nx.spring_layout(G_birds, seed=42)  # Generate positions for the subgraph
    # Wrap node labels for better readability
    labels_birds = {node: wrap_text(node) for node in G_birds.nodes()}
    # Draw the graph with wrapped labels
    nx.draw(G_birds, pos_birds, with_labels=True, labels=labels_birds, node_color="lightblue", font_weight="bold", node_size=3000, font_size=12, ax=ax1)
    edge_labels_birds = {(u, v): f"{G.nodes[u]['origin']} -> {G.nodes[v]['origin']}" for u, v in G_birds.edges()}
    nx.draw_networkx_edge_labels(G_birds, pos_birds, edge_labels=edge_labels_birds, font_size=12, font_color="red", ax=ax1)
    ax1.set_title("Birds")

    # Second subplot for Timbers
    ax2 = axes[1]
    pos_timbers = nx.spring_layout(G_timbers, seed=42)  # Generate positions for the subgraph
    # Wrap node labels for better readability
    labels_timbers = {node: wrap_text(node) for node in G_timbers.nodes()}
    # Draw the graph with wrapped labels
    nx.draw(G_timbers, pos_timbers, with_labels=True, labels=labels_timbers, node_color="lightgreen", font_weight="bold", node_size=3000, font_size=12, ax=ax2)
    edge_labels_timbers = {(u, v): f"{G.nodes[u]['origin']} -> {G.nodes[v]['origin']}" for u, v in G_timbers.edges()}
    nx.draw_networkx_edge_labels(G_timbers, pos_timbers, edge_labels=edge_labels_timbers, font_size=12, font_color="red", ax=ax2)
    ax2.set_title("Timbers")

    fig.suptitle(title)  # Overall figure title
    plt.tight_layout()
    plt.show()
# Main function to run the full process
def main():
    # 1. Create the sample graph
    G, bird_origins, timber_origins = create_sample_graph()

    # 2. Generate embeddings for the graph
    embedding_df, embeddings = create_embeddings_for_graph(G)

    # Display the original graph data with the first 2 embedding dimensions per node
    print("Graph Data (first 2 embedding dimensions per node):")

    # Create a DataFrame to display the graph edges and their relations
    original_df = pd.DataFrame(list(G.edges(data=True)), columns=["Node1", "Node2", "Relation"])

    # Map origin information to the Relation column based on the two nodes
    def get_relation_value(node1, node2):
        # Check whether both nodes are birds, or both are timbers
        if node1 in bird_origins and node2 in bird_origins:
            return bird_origins.get(node1, "Unknown") + " -> " + bird_origins.get(node2, "Unknown")
        elif node1 in timber_origins and node2 in timber_origins:
            return timber_origins.get(node1, "Unknown") + " -> " + timber_origins.get(node2, "Unknown")
        else:
            return "Unknown"

    # Apply get_relation_value to each edge
    original_df["Relation"] = original_df.apply(lambda row: get_relation_value(row["Node1"], row["Node2"]), axis=1)

    # Add only the first 2 dimensions of each node's embedding to original_df
    node_embeddings = {node: embeddings[i][:2] for i, node in enumerate(G.nodes)}
    # Then, for each row in original_df, map the nodes to their embeddings
    original_df["Node1 Embedding"] = original_df["Node1"].map(node_embeddings)
    original_df["Node2 Embedding"] = original_df["Node2"].map(node_embeddings)

    # Add the index as the first column in original_df
    original_df.reset_index(inplace=True)
    original_df.rename(columns={"index": "Index"}, inplace=True)
    print(original_df.to_markdown(index=False))

    # 3. Visualize the graph (Birds and Timbers side by side)
    print("\nVisualising Nodes and Edges")
    birds = [n for n, d in G.nodes(data=True) if d["category"] == "bird"]
    timbers = [n for n, d in G.nodes(data=True) if d["category"] == "timber"]
    visualize_graph_side_by_side(G, birds, timbers, title="Birds and Timbers")

    # 4. Store the embeddings in FAISS
    index = store_embeddings_in_faiss(embeddings)

    # 5. Simulate a user query (here, "emu")
    user_query = "emu"
    query_embedding = get_bert_embeddings([user_query])

    # 6. Retrieve the top 3 most similar nodes for the user query
    indices, distances = retrieve_top_k_similar(index, query_embedding, k=3)

    # Map the returned indices back to node names
    predicted_nodes = [list(G.nodes)[i] for i in indices[0]]

    # Display the user query and the predicted graph data (top 3 similar nodes)
    print(f"\nUser Query: {user_query}")
    predicted_df = pd.DataFrame({
        "Index": indices[0],  # FAISS index of each predicted node
        "Predicted Node": predicted_nodes,
        "Embedding (first 2 dimensions)": [embeddings[i][:2] for i in indices[0]]  # Slice the first 2 dimensions
    })

    # 7. Add the Relation column to predicted_df based on origin information
    def get_predicted_relation(node):
        if node in bird_origins:
            return bird_origins[node]
        elif node in timber_origins:
            return timber_origins[node]
        else:
            return "Unknown"

    # Populate the 'Relation' column
    predicted_df["Relation"] = predicted_df["Predicted Node"].map(get_predicted_relation)

    # Sort predicted_df by the 'Index' column in ascending order
    predicted_df = predicted_df.sort_values(by="Index", ascending=True)
    print(predicted_df.to_markdown(index=False))

    # 8. Visualize the predicted graph (subgraph of the predicted nodes)
    print("\nVisualizing Predicted Node Graph")
    predicted_G = G.subgraph(predicted_nodes)
    visualize_graph_side_by_side(predicted_G, birds, timbers, title="Predicted Graph")

if __name__ == "__main__":
    main()
Graph Data (first 2 embedding dimensions per node):

| Index | Node1  | Node2         | Relation               | Node1 Embedding           | Node2 Embedding           |
|------:|:-------|:--------------|:-----------------------|:--------------------------|:--------------------------|
|     0 | eagle  | parrot        | USA -> Australia       | [-0.08763738 -0.18008912] | [-0.21244605 -0.08693982] |
|     1 | parrot | sparrow       | Australia -> Europe    | [-0.21244605 -0.08693982] | [-0.30793023 -0.19978818] |
|     2 | parrot | emu           | Australia -> Australia | [-0.21244605 -0.08693982] | [-0.28172386 -0.08433155] |
|     3 | oak    | cedar         | USA -> Canada          | [-0.19053563 -0.02666191] | [-0.15451434 0.1392495 ]  |
|     4 | oak    | tasmanian-oak | USA -> Australia       | [-0.19053563 -0.02666191] | [-0.11889555 0.14518885]  |
|     5 | cedar  | maple         | Canada -> Canada       | [-0.15451434 0.1392495 ]  | [-0.17470792 0.18844433]  |

Visualising Nodes and Edges
User Query: emu

| Index | Predicted Node | Embedding (first 2 dimensions) | Relation  |
|------:|:---------------|:-------------------------------|:----------|
|     1 | parrot         | [-0.21244605 -0.08693982]      | Australia |
|     3 | emu            | [-0.28172386 -0.08433155]      | Australia |
|     4 | oak            | [-0.19053563 -0.02666191]      | USA       |

Visualizing Predicted Node Graph
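A possible follow-up, not implemented above: the node embeddings were built from formatted texts such as "emu (bird from Australia)", while the query embeds only the bare word "emu", which partly explains cross-category hits like "oak". Below is a hedged sketch that formats the query the same way and switches to a cosine-similarity index; it assumes the cells above have been run (it reuses G, embeddings and get_bert_embeddings), and the formatted query text is purely illustrative.

# Sketch only: reuses G, embeddings and get_bert_embeddings from the cells above.
import numpy as np
import faiss

emb = embeddings.astype(np.float32)
faiss.normalize_L2(emb)                         # normalise rows to unit length (in place)
cosine_index = faiss.IndexFlatIP(emb.shape[1])  # inner product == cosine on unit vectors
cosine_index.add(emb)

query_text = "emu (bird from Australia)"        # same template as the node texts (illustrative)
query_vec = get_bert_embeddings([query_text]).astype(np.float32)
faiss.normalize_L2(query_vec)
scores, ids = cosine_index.search(query_vec, 3)  # top-3 cosine matches
print([list(G.nodes)[i] for i in ids[0]])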